""" Value function computation with simulation.

"""
import numpy as np
from dask.distributed import Client, progress
import pandas as pd
from numba import njit
try:
    from src.interpolation.splines import UCGrid, nodes, eval_linear
except:
    from ..interpolation.splines import UCGrid, nodes, eval_linear
from src.models_new import states 


@njit
def do_simulation(state, sim_length, seed, price_new, price_extend,
                  prob_match, prob_extend, const, r, sigma, fortnight):
    """ Forward simulate a stream of prices. Note that since I use njit
    the code is written quite simply (no dictionaries etc).

    Every period exactly one price and one state are recorded. While no
    contract is running, a contract duration is sampled from (0, 2, 3, 4);
    a draw of 0 means no contract and a recorded price of 0.0. While a
    contract is running, the price set when it started is recorded again
    each period. In the last period of a contract an extension is drawn;
    if it succeeds, the extension price is what gets recorded from the
    following period onwards.

    Args:
        state (np.array): initial state to start the simulation from
        sim_length (int): how many periods to do the simulation for
        seed (int): numpy seed for sampling
        price_new (function): get the price of a new contract given the
            state and duration
        price_extend (function): get the price of extending a contract
        prob_match (function): get the prob. of matching a new contract
        prob_extend (function): get the prob. of extending a contract
        const (np.array): constant in the AR(1) state evolution
        r (np.array): slope in the AR(1) state evolution
        sigma (float): std. deviation of the errors for the gas price
            evolution
        fortnight (bool): if True the sampled contract duration is doubled
            (presumably because a period is then a fortnight rather than
            a month -- TODO confirm against the caller)

    Returns:
        prices_list (list): list of the stream of prices, one entry per
            simulated period
        states_list (list): list of the stream of states, one entry per
            simulated period
    """
    np.random.seed(seed)
    # number of periods left on the current contract (0 = no contract)
    contract_periods = 0
    prices_list = list()
    states_list = list()

    for _ in range(sim_length):
        # match or countdown
        # (contract_periods starts at 0, so this branch always runs in the
        # first period and defines price_t / contract_duration)
        if contract_periods == 0:
            # prob match generates (4 x 1) vector
            prob_match_t = prob_match(state)
            contract_duration = states.rand_choice_nb(
                arr=np.array([0, 2, 3, 4]), prob=prob_match_t
            )
            if fortnight == True:
                contract_duration = contract_duration * 2
            contract_periods = contract_duration
            if contract_periods >= 1:
                price_t = price_new(state, contract_duration)
            else:
                # a duration of 0 means no contract this period
                price_t = 0.0

        # record everything (just update same price from previous if periods >= 1)
        prices_list.append(price_t)
        states_list.append(state)

        # potentially extend the contract (all payoffs are received in the next period)
        if contract_periods == 1:
            prob_extend_t = prob_extend(state, contract_duration)
            # single Bernoulli draw: 1 = the contract is extended
            extend_t = np.random.binomial(n=1, p=prob_extend_t)
            if extend_t == 1:
                price_t = price_extend(state, contract_duration)
                # Note: use the +1 here since the contract will start the next period
                contract_periods = contract_duration + 1
            else:
                # no extension: the decrement below brings contract_periods
                # to 0, so a new match is attempted next period
                price_t = 0.0

        # update everything
        state = states.next_state(state, const, r, sigma)
        if contract_periods >= 1:
            contract_periods += -1

    return prices_list, states_list


def do_simulation_over_seeds(
        state, sim_length, price_new, price_extend,
        prob_match, prob_extend, const, r, sigma, seeds, fortnight):
    """ Run `do_simulation` once per seed and collect the results.

    All arguments other than `seeds` are forwarded unchanged to
    `do_simulation`; only the seed differs between runs.

    Args:
        seeds (iterable): one seed per simulation run

    Returns:
        prices_list (list): the price stream of each run, in seed order
        states_list (list): the state stream of each run, in seed order
    """
    # everything that is identical across runs
    common_args = dict(
        state=state,
        sim_length=sim_length,
        price_new=price_new,
        price_extend=price_extend,
        prob_match=prob_match,
        prob_extend=prob_extend,
        const=const,
        r=r,
        sigma=sigma,
        fortnight=fortnight,
    )
    runs = [do_simulation(seed=seed, **common_args) for seed in seeds]

    # each run is a (prices, states) pair; split them into two parallel lists
    prices_list = [run[0] for run in runs]
    states_list = [run[1] for run in runs]
    return prices_list, states_list


def get_value_at_state(state_initial, first_stage, beta, sim_length, seeds, fortnight,
                       verbose=False):
    """ Monte Carlo estimate of the value function at a single state.

    One forward simulation is run per seed. Each simulated price stream is
    discounted with beta ** period and summed, and the value is the mean of
    those discounted sums over the seeds.

    Args:
        state_initial (np.array): state where the value function is
            evaluated.
        first_stage (dict): information from the first stage
            (e.g. prob of matching); unpacked as keyword arguments into
            do_simulation_over_seeds
        beta (float): discount factor
        sim_length (int): length of the simulation
        seeds (list): list of the different seeds for the simulation
        fortnight (bool): forwarded to the simulation
        verbose (bool): if True, also return the intermediate data frames

    Returns:
        value (float): value function at the state. If verbose is True,
            a tuple (value, prices by seed x period, discounted prices by
            period x seed, states by seed x period) is returned instead.
    """
    simulated_prices, simulated_states = do_simulation_over_seeds(
        state=state_initial,
        sim_length=sim_length,
        seeds=seeds,
        fortnight=fortnight,
        **first_stage
    )

    # rows are seeds, columns are periods
    df_prices_by_sim = pd.DataFrame(dict(zip(seeds, simulated_prices))).T

    # the Series index (0, 1, ...) lines up with the period columns
    discount_factors = pd.Series([beta ** period for period in range(sim_length)])
    df_prices_by_sim_with_beta = (df_prices_by_sim * discount_factors).T

    # discounted sum per seed, then the average over seeds
    value = df_prices_by_sim_with_beta.sum(axis=0).mean()

    if not verbose:
        return value

    df_states_by_sim = pd.DataFrame(dict(zip(seeds, simulated_states))).T
    return value, df_prices_by_sim, df_prices_by_sim_with_beta, df_states_by_sim


def build_value_functions(g_grid, n_grid, first_stage, beta, sim_length, seeds, fortnight,
                          options):
    """ Evaluate the simulated value function on every node of the
        interpolation grid.

    Note: this uses parallelization using dask.

    Args:
        g_grid (int): number of grid points along the first (gas price)
            axis; also used as the first dimension of the reshaped values
        n_grid (int): number of grid points along each of the three
            remaining axes
        first_stage (dict): information from the first stage
            (e.g. prob of matching)
        beta (float): discount factor
        sim_length (int): length of simulation
        seeds (list): seeds to use for each simulation
        fortnight (bool): forwarded to get_value_at_state
        options (dict): parallelization options; the keys
            'threads_per_worker' and 'n_workers' are read

    Returns:
        uniform_grid: the grid object built by UCGrid
        values_to_interp (np.array): values reshaped to
            (g_grid, n_grid, n_grid, n_grid)
        values (dict): maps each grid node (as a tuple) to its value
    """
    # grid axes: presumably (lower bound, upper bound, number of points)
    # -- confirm against UCGrid
    uniform_grid = UCGrid(
        (2, 15, g_grid),
        (2, 35, n_grid),
        (2, 35, n_grid),
        (2, 35, n_grid)
    )
    nodes_list = list(nodes(uniform_grid))

    # one task per grid node, spread over the dask workers
    with Client(
        threads_per_worker=options['threads_per_worker'],
        n_workers=options['n_workers']
    ) as client:
        futures = client.map(
            get_value_at_state,
            nodes_list,
            beta=beta,
            sim_length=sim_length,
            seeds=seeds,
            fortnight=fortnight,
            first_stage=first_stage
        )
        progress(futures)

        # futures come back in the same order as the nodes they were mapped from
        values = {
            tuple(node): future.result()
            for node, future in zip(nodes_list, futures)
        }

        # relies on dict insertion order matching the node order
        values_to_interp = np.array(list(values.values())).reshape(
            g_grid, n_grid, n_grid, n_grid)

    return uniform_grid, values_to_interp, values


def build_ols_predictor(coefs):
    """ Build a compiled linear predictor from fitted coefficients.

    The returned function evaluates a full second-order polynomial in the
    first four entries of the state, plus a linear term in tau.

    Args:
        coefs: indexable with 16 entries, ordered as: constant, the 4
            linear terms, the 10 second-order terms, and the tau term

    Returns:
        function: ols_predictor(state, tau) -> predicted value
    """

    @njit
    def ols_predictor(state, tau):
        g, n_l, n_m, n_h = state[0], state[1], state[2], state[3]

        # constant and linear terms
        pred = coefs[0]
        pred = pred + coefs[1] * g
        pred = pred + coefs[2] * n_l
        pred = pred + coefs[3] * n_m
        pred = pred + coefs[4] * n_h

        # squares and interactions
        pred = pred + coefs[5] * g * g
        pred = pred + coefs[6] * g * n_l
        pred = pred + coefs[7] * g * n_m
        pred = pred + coefs[8] * g * n_h
        pred = pred + coefs[9] * n_l * n_l
        pred = pred + coefs[10] * n_l * n_m
        pred = pred + coefs[11] * n_l * n_h
        pred = pred + coefs[12] * n_m * n_m
        pred = pred + coefs[13] * n_m * n_h
        pred = pred + coefs[14] * n_h * n_h

        # contract duration term
        pred = pred + coefs[15] * tau
        return pred

    return ols_predictor


def build_mnl_predictor(coefs):
    """ Build a compiled multinomial-logit probability predictor.

    For each of the 4 alternatives a linear index is computed as a full
    second-order polynomial in the first four entries of the state. The
    indices are then turned into probabilities with a softmax.

    Args:
        coefs: 2-D array indexed as coefs[i, k], with i = 0..3 the
            alternative and k = 0..14 the term (constant, 4 linear terms,
            10 second-order terms)

    Returns:
        function: mnl_predictor(state) -> np.array of 4 probabilities
            that sum to one
    """

    @njit
    def mnl_predictor(state):
        g = state[0]
        n_l = state[1]
        n_m = state[2]
        n_h = state[3]

        x = np.zeros(4)
        for i in range(4):
            x[i] = (
                coefs[i, 0]
                + coefs[i, 1] * g
                + coefs[i, 2] * n_l
                + coefs[i, 3] * n_m
                + coefs[i, 4] * n_h
                + coefs[i, 5] * g * g
                + coefs[i, 6] * g * n_l
                + coefs[i, 7] * g * n_m
                + coefs[i, 8] * g * n_h
                + coefs[i, 9] * n_l * n_l
                + coefs[i, 10] * n_l * n_m
                + coefs[i, 11] * n_l * n_h
                + coefs[i, 12] * n_m * n_m
                + coefs[i, 13] * n_m * n_h
                + coefs[i, 14] * n_h * n_h
            )

        # Softmax is unchanged by subtracting a constant from every index.
        # Shifting by the max keeps every exponent <= 0, so np.exp cannot
        # overflow to inf and produce inf / inf = nan probabilities.
        x = x - np.max(x)

        weights = np.exp(x)
        return weights / np.sum(weights)

    return mnl_predictor


def build_prob_extension_predictor(coefs):
    """ Build a compiled logit predictor for the extension probability.

    The linear index is a full second-order polynomial in the first four
    entries of the state plus a linear term in tau; the probability is
    the logistic function of that index.

    Args:
        coefs: indexable with 16 entries, ordered as: constant, the 4
            linear terms, the 10 second-order terms, and the tau term

    Returns:
        function: prob_extension_predictor(state, tau) -> probability
            in [0, 1]
    """

    @njit
    def prob_extension_predictor(state, tau):
        g = state[0]
        n_l = state[1]
        n_m = state[2]
        n_h = state[3]

        x = (
            coefs[0]
            + coefs[1] * g
            + coefs[2] * n_l
            + coefs[3] * n_m
            + coefs[4] * n_h
            + coefs[5] * g * g
            + coefs[6] * g * n_l
            + coefs[7] * g * n_m
            + coefs[8] * g * n_h
            + coefs[9] * n_l * n_l
            + coefs[10] * n_l * n_m
            + coefs[11] * n_l * n_h
            + coefs[12] * n_m * n_m
            + coefs[13] * n_m * n_h
            + coefs[14] * n_h * n_h
            + coefs[15] * tau
        )

        # Logistic function written as 1 / (1 + exp(-x)). The algebraically
        # equal exp(x) / (1 + exp(x)) evaluates to inf / inf = nan once
        # exp(x) overflows (x above roughly 709). This form instead goes
        # to 1.0 for large x and to 0.0 for very negative x.
        return 1.0 / (1.0 + np.exp(-x))

    return prob_extension_predictor


def build_prob_extension_predictor_with_mri(coefs):
    """ Build a compiled logit predictor for the extension probability
    that also conditions on `mri`.

    The linear index is a full second-order polynomial in the first four
    entries of the state plus linear terms in tau and mri; the
    probability is the logistic function of that index.

    Args:
        coefs: indexable with 17 entries, ordered as: constant, the 4
            linear terms, the 10 second-order terms, the tau term and
            the mri term

    Returns:
        function: prob_extension_predictor_with_mri(state, tau, mri)
            -> probability in [0, 1]
    """

    @njit
    def prob_extension_predictor_with_mri(state, tau, mri):
        g = state[0]
        n_l = state[1]
        n_m = state[2]
        n_h = state[3]

        x = (
            coefs[0]
            + coefs[1] * g
            + coefs[2] * n_l
            + coefs[3] * n_m
            + coefs[4] * n_h
            + coefs[5] * g * g
            + coefs[6] * g * n_l
            + coefs[7] * g * n_m
            + coefs[8] * g * n_h
            + coefs[9] * n_l * n_l
            + coefs[10] * n_l * n_m
            + coefs[11] * n_l * n_h
            + coefs[12] * n_m * n_m
            + coefs[13] * n_m * n_h
            + coefs[14] * n_h * n_h
            + coefs[15] * tau
            + coefs[16] * mri
        )

        # Logistic function written as 1 / (1 + exp(-x)). The algebraically
        # equal exp(x) / (1 + exp(x)) evaluates to inf / inf = nan once
        # exp(x) overflows (x above roughly 709). This form instead goes
        # to 1.0 for large x and to 0.0 for very negative x.
        return 1.0 / (1.0 + np.exp(-x))

    return prob_extension_predictor_with_mri


def build_price_extension_predictor(coefs):
    """ Build a compiled predictor for the price of a contract extension.

    The returned function evaluates a full second-order polynomial in the
    first four entries of the state, plus a linear term in tau. The value
    is returned as is (no transformation is applied).

    Args:
        coefs: indexable with 16 entries, ordered as: constant, the 4
            linear terms, the 10 second-order terms, and the tau term

    Returns:
        function: price_extension_predictor(state, tau) -> predicted price
    """

    @njit
    def price_extension_predictor(state, tau):
        g, n_l, n_m, n_h = state[0], state[1], state[2], state[3]

        # level terms
        price = coefs[0]
        price = price + coefs[1] * g
        price = price + coefs[2] * n_l
        price = price + coefs[3] * n_m
        price = price + coefs[4] * n_h

        # second-order terms, in coefficient order
        price = price + coefs[5] * g * g
        price = price + coefs[6] * g * n_l
        price = price + coefs[7] * g * n_m
        price = price + coefs[8] * g * n_h
        price = price + coefs[9] * n_l * n_l
        price = price + coefs[10] * n_l * n_m
        price = price + coefs[11] * n_l * n_h
        price = price + coefs[12] * n_m * n_m
        price = price + coefs[13] * n_m * n_h
        price = price + coefs[14] * n_h * n_h

        # duration term
        price = price + coefs[15] * tau
        return price

    return price_extension_predictor